"""
#
# Flow stability for dynamic community detection https://arxiv.org/abs/2101.06131v2
#
# Copyright (C) 2021 Alexandre Bovet <alexandre.bovet@maths.ox.ac.uk>
#
# This program is free software; you can redistribute it and/or modify it under
# the terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 3 of the License, or (at your option) any
# later version.
#
# This program is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.


This script adds country affiliations to the edges dataframe.

The result is saved as `all_journals_disamb_edges_countries.csv.gz`.

"""


import os
import pandas as pd
import numpy as np
from multiprocessing import Pool
import time
import json

import jellyfish 

# directory in which `worker` writes one CSV file per DOI
savedir = '../paper_data/aps/doi_country'
# number of processes in the multiprocessing pool
nproc = 6

# NOTE(review): this unconditional raise stops the script when the file is
# executed top to bottom; presumably it is a guard so that the `#%%` cells
# below are only run one at a time on purpose -- confirm before removing.
raise Exception
#%% now add country infos to edges (i.e. papers)

# ID, name ,num_article
# this is the author name disambiguation of the APS dataset
# available at Supplementary Material at https://doi.org/10.1126/science.aaf5239
df_authors = pd.read_csv('../data/aps/APS_authors.csv', index_col=0)

# load edge_list (papers with from 2 to 10 authors)
df_edges = pd.read_csv('../data/aps/all_journals_disamb_edges.csv.gz', index_col=0)

# per-DOI record with the keys 'affiliations' and 'authors' (see `worker`)
with open('../data/aps/doi_to_affilitations_authors_dict.json', 'r') as fopen:
    doi_to_affilitations = json.load(fopen)

# affiliation name -> list of countries.
# This file is an input: it must be opened read-only (mode 'w' would truncate
# it) and parsed with `json.load` (the json module has no `read` function).
with open('../data/aps/affiliations_to_countries.json', 'r') as fopen:
    final_affil_to_country = json.load(fopen)

#%%

num_dois = df_edges.doi.unique().size


def _match_edge_authors_to_affiliation_ids(edge_authors_list, aff_authors):
    """Pair edge authors with the affiliation IDs of the most similar listed author.

    Parameters
    ----------
    edge_authors_list : list of (author id, author name) tuples
        Authors appearing in the edges of one paper.
    aff_authors : list of dict
        Authors listed in the affiliation record of the same paper; each dict
        is read through its 'name' and 'affiliationIds' keys.

    Returns
    -------
    dict
        author id -> 'affiliationIds' of the listed author it was paired with.
        Edge authors left over once the listed authors are exhausted are
        absent from the dict.
    """
    n_aff = len(aff_authors)
    n_edge = len(edge_authors_list)

    # name similarity matrix: one row per listed author, one column per edge
    # author. The explicit reshape keeps the matrix two-dimensional when one
    # of the two lists is empty (np.array([]) alone would be 1-D and
    # `shape[1]` would raise an IndexError).
    # NOTE(review): only the listed name is lower-cased; presumably the names
    # in `df_authors` are already lower case -- TODO confirm.
    sim_mat = np.array(
        [[jellyfish.jaro_winkler_similarity(auth['name'].lower(), edge_name)
          for _, edge_name in edge_authors_list]
         for auth in aff_authors],
        dtype=float).reshape(n_aff, n_edge)

    author_to_aff_ids = {}

    # 1) exact matches: accepted only when the pair is unambiguous, i.e. the
    #    listed author equals exactly one edge author and that edge author
    #    equals no other listed author
    matched_rows = set()
    matched_cols = set()
    for i, auth in enumerate(aff_authors):
        exact_cols, = np.where(sim_mat[i, :] == 1.0)
        if exact_cols.size != 1:
            continue
        j = int(exact_cols[0])
        exact_rows, = np.where(sim_mat[:, j] == 1.0)
        if exact_rows.size == 1:
            author_to_aff_ids[edge_authors_list[j][0]] = auth['affiliationIds']
            matched_rows.add(i)
            matched_cols.add(j)

    # remove exact matches
    edge_left = [j for j in range(n_edge) if j not in matched_cols]
    aff_left = [i for i in range(n_aff) if i not in matched_rows]

    # 2) greedy matching: repeatedly pair the two remaining names with the
    #    highest similarity, until one of the two sides is exhausted
    while edge_left and aff_left:
        reduced = sim_mat[aff_left, :][:, edge_left]
        row_max, col_max = np.unravel_index(np.argmax(reduced), reduced.shape)

        edge_idx = edge_left.pop(col_max)
        aff_idx = aff_left.pop(row_max)
        author_to_aff_ids[edge_authors_list[edge_idx][0]] = \
            aff_authors[aff_idx]['affiliationIds']

    return author_to_aff_ids


def worker(doi_group):
    """Add country columns to the edges of one paper and save them to disk.

    Parameters
    ----------
    doi_group : tuple (doi, DataFrame)
        One item of `df_edges.groupby(by='doi')`. The dataframe is read
        through its `n1` and `n2` columns, whose values are used as row
        positions in `df_authors`.

    The columns `n1_country` and `n2_country` are filled as follows:

    - a single affiliation: its countries (comma separated) for all authors;
      the columns are not created when the affiliation has no known country,
    - several affiliations mapping to the same country string: that string
      for all authors,
    - several affiliations with different countries: each edge author is
      paired with an author of the affiliation record by name similarity and
      receives the countries of that author's affiliations.

    Nothing is returned: the dataframe is written to
    `savedir/df_edge_<last part of the doi>.csv` with ';' as separator.

    Raises a KeyError if `doi` is missing from `doi_to_affilitations`.
    """
    doi, df_doi = doi_group

    record = doi_to_affilitations[doi]
    affiliations = record['affiliations']

    if len(affiliations) == 1:
        # same affilitation for all authors
        aff = affiliations[0]['name']
        if aff in final_affil_to_country:
            country = ','.join(final_affil_to_country[aff])
            df_doi['n1_country'] = country
            df_doi['n2_country'] = country

    elif len(affiliations) > 1:
        # several affilitations
        # one country string per affiliation ('' when the country is unknown)
        countries = [','.join(final_affil_to_country.get(aff['name'], ''))
                     for aff in affiliations]

        if len(set(countries)) == 1:
            # same country
            df_doi['n1_country'] = countries[0]
            df_doi['n2_country'] = countries[0]

        else:
            # different countries: need to match country to authors
            edge_authors = {n: df_authors.iloc[n]['name'] for n in
                            set(df_doi.n1.tolist() + df_doi.n2.tolist())}

            edge_authors_to_affID = _match_edge_authors_to_affiliation_ids(
                list(edge_authors.items()), record['authors'])

            affID_to_country = {aff['id']: country for aff, country
                                in zip(affiliations, countries)}

            for n in edge_authors:
                if n in edge_authors_to_affID:
                    country = ','.join(affID_to_country[aID]
                                       for aID in edge_authors_to_affID[n])
                    df_doi.loc[df_doi.n1 == n, 'n1_country'] = country
                    df_doi.loc[df_doi.n2 == n, 'n2_country'] = country

    #save results
    df_doi.to_csv(os.path.join(savedir,f'df_edge_{doi.split("/")[-1]}.csv'),sep=';')
#%%

if __name__ == '__main__':

    t00 = time.time()

    # every worker writes its result into `savedir`: make sure the directory
    # exists before the pool starts, otherwise each `to_csv` call fails
    os.makedirs(savedir, exist_ok=True)

    print('starting pool of {0} cpus'.format(nproc))
    # one task per DOI; the workers return nothing, their results are the
    # files written in `savedir`
    with Pool(nproc) as p:
        work = p.map_async(worker, df_edges.groupby(by='doi'),
                           chunksize=1)
        data = work.get()


    print('***** Finished! in {:.2f}'.format(time.time()-t00))

#%% merge all results

    df_edges = pd.read_csv('../data/aps/all_journals_disamb_edges.csv.gz', index_col=0)

    #load df_dois
    # only the files written by `worker` ('df_edge_<...>.csv') are read, so
    # that any other entry of the directory does not break the merge
    files = sorted(fname for fname in os.listdir(savedir)
                   if fname.startswith('df_edge_') and fname.endswith('.csv'))
    # NOTE(review): parse_dates=[4] presumably targets a date column of the
    # edge list -- TODO confirm the column position
    dfs = [pd.read_csv(os.path.join(savedir, fname), sep=';', index_col=0,
                       parse_dates=[4])
           for fname in files]

    new_dfs = pd.concat(dfs)
    new_df_edges = df_edges.copy()
    # copy the country columns of the per-DOI results onto the rows of the
    # full edge list that share their index label
    new_df_edges.loc[df_edges.index.isin(new_dfs.index),['n1_country','n2_country']] = new_dfs[['n1_country','n2_country']]


    new_df_edges.to_csv('../data/aps/all_journals_disamb_edges_countries.csv.gz')
    
